{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Minimal tutorial on packing and unpacking sequences in PyTorch, aka how to use `pack_padded_sequence` and  `pad_packed_sequence`\n",
    "\n",
    "This is a Jupyter version of [@Tushar-N's gist](https://gist.github.com/Tushar-N/dfca335e370a2bc3bc79876e6270099e), with comments from [@HarshTrivedi's repo](https://github.com/HarshTrivedi/packing-unpacking-pytorch-minimal-tutorial).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from https://github.com/HarshTrivedi/packing-unpacking-pytorch-minimal-tutorial\n",
    "import torch\n",
    "from torch import LongTensor\n",
    "from torch.nn import Embedding, LSTM\n",
    "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
    "\n",
    "## We want to run LSTM on a batch of 3 character sequences ['long_str', 'tiny', 'medium']\n",
    "#\n",
    "#     Step 1: Construct Vocabulary\n",
    "#     Step 2: Load indexed data (list of instances, where each instance is list of character indices)\n",
    "#     Step 3: Make Model\n",
    "#  *  Step 4: Pad instances with 0s till max length sequence\n",
    "#  *  Step 5: Sort instances by sequence length in descending order\n",
    "#  *  Step 6: Embed the instances\n",
    "#  *  Step 7: Call pack_padded_sequence with embedded instances and sequence lengths\n",
    "#  *  Step 8: Forward with LSTM\n",
    "#  *  Step 9: Call pad_packed_sequence if required / or just pick last hidden vector\n",
    "#  *  Summary of Shape Transformations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The batch for the LSTM: three character sequences of different lengths.\n",
    "seqs = [\n",
    "    'long_str',  # 8 characters\n",
    "    'tiny',      # 4 characters\n",
    "    'medium',    # 6 characters\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 1: Construct Vocabulary ##\n",
    "##------------------------------##\n",
    "# Collect every distinct character that occurs in the batch, in sorted order ...\n",
    "unique_chars = sorted(set(''.join(seqs)))\n",
    "# ... and put '<pad>' in front so that the padding token is guaranteed to get index 0.\n",
    "vocab = ['<pad>', *unique_chars]\n",
    "# vocab => ['<pad>', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['<pad>', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 2: Load indexed data (list of instances, where each instance is list of character indices) ##\n",
    "##-------------------------------------------------------------------------------------------------##\n",
    "# Build the char -> index lookup once; calling vocab.index(tok) for every character\n",
    "# would rescan the vocabulary list on each lookup.\n",
    "char_to_idx = {char: idx for idx, char in enumerate(vocab)}\n",
    "# One list of vocabulary indices per input string, in the same order as `seqs`.\n",
    "vectorized_seqs = [[char_to_idx[tok] for tok in seq] for seq in seqs]\n",
    "# vectorized_seqs => [[6, 9, 8, 4, 1, 11, 12, 10],   # long_str\n",
    "#                     [12, 5, 8, 14],                 # tiny\n",
    "#                     [7, 3, 2, 5, 13, 7]]            # medium"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[6, 9, 8, 4, 1, 11, 12, 10], [12, 5, 8, 14], [7, 3, 2, 5, 13, 7]]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vectorized_seqs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 3: Make Model ##\n",
    "##--------------------##\n",
    "# One embedding row per vocabulary entry; every character index maps to a 4-dim vector.\n",
    "embed = Embedding(num_embeddings=len(vocab), embedding_dim=4)\n",
    "# The LSTM input_size has to equal the embedding dimension (4); hidden_dim = 5.\n",
    "# batch_first=True: padded tensors are laid out as (batch_size X seq_len X feature_dim).\n",
    "lstm = LSTM(input_size=4, hidden_size=5, batch_first=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 4: Pad instances with 0s till max length sequence ##\n",
    "##--------------------------------------------------------##\n",
    "\n",
    "# get the length of each seq in your batch\n",
    "seq_lengths = LongTensor(list(map(len, vectorized_seqs)))\n",
    "# seq_lengths => [ 8, 4,  6]\n",
    "# batch_sum_seq_len: 8 + 4 + 6 = 18\n",
    "# max_seq_len: 8\n",
    "\n",
    "# Allocate the padded batch directly as an int64 tensor of zeros (0 is the '<pad>' index).\n",
    "# A plain tensor is used: `Variable` is not imported in this notebook and is not needed,\n",
    "# since tensors and Variables were merged in PyTorch 0.4.\n",
    "max_seq_len = int(seq_lengths.max())\n",
    "seq_tensor = torch.zeros((len(vectorized_seqs), max_seq_len), dtype=torch.long)\n",
    "# seq_tensor => [[0 0 0 0 0 0 0 0]\n",
    "#                [0 0 0 0 0 0 0 0]\n",
    "#                [0 0 0 0 0 0 0 0]]\n",
    "\n",
    "# Copy each sequence into the left part of its row; the remaining positions stay 0 (padding).\n",
    "for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):\n",
    "    seq_tensor[idx, :seqlen] = LongTensor(seq)\n",
    "# seq_tensor => [[ 6  9  8  4  1 11 12 10]          # long_str\n",
    "#                [12  5  8 14  0  0  0  0]          # tiny\n",
    "#                [ 7  3  2  5 13  7  0  0]]         # medium\n",
    "# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([8, 4, 6])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_lengths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 6,  9,  8,  4,  1, 11, 12, 10],\n",
       "        [12,  5,  8, 14,  0,  0,  0,  0],\n",
       "        [ 7,  3,  2,  5, 13,  7,  0,  0]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 5: Sort instances by sequence length in descending order ##\n",
    "##---------------------------------------------------------------##\n",
    "\n",
    "# The batch is reordered longest-first before it gets packed in Step 7.\n",
    "# NOTE: seq_lengths is overwritten with its sorted version, so re-running this cell\n",
    "# yields the identity permutation perm_idx = [0, 1, 2] instead of [0, 2, 1].\n",
    "seq_lengths, perm_idx = torch.sort(seq_lengths, dim=0, descending=True)\n",
    "# seq_lengths => [8, 6, 4]\n",
    "# perm_idx    => [0, 2, 1]   (first run on a fresh kernel)\n",
    "\n",
    "# Apply the same permutation to the rows of the padded batch.\n",
    "seq_tensor = seq_tensor[perm_idx]\n",
    "# seq_tensor => [[ 6  9  8  4  1 11 12 10]           # long_str\n",
    "#                [ 7  3  2  5 13  7  0  0]           # medium\n",
    "#                [12  5  8 14  0  0  0  0]]          # tiny\n",
    "# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1, 2])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "perm_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 6,  9,  8,  4,  1, 11, 12, 10],\n",
       "        [ 7,  3,  2,  5, 13,  7,  0,  0],\n",
       "        [12,  5,  8, 14,  0,  0,  0,  0]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 6: Embed the instances ##\n",
    "##-----------------------------##\n",
    "\n",
    "# Look up one embedding vector per character index; '<pad>' (index 0) has an ordinary embedding row too.\n",
    "embedded_seq_tensor = embed(input=seq_tensor)\n",
    "# embedded_seq_tensor =>\n",
    "#                       [[[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]     l\n",
    "#                         [-0.23622951  2.0361056   0.15435742 -0.04513785]     o\n",
    "#                         [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n\n",
    "#                         [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]     g\n",
    "#                         [-1.6334635  -0.6100042   1.7509955  -1.931793  ]     _\n",
    "#                         [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]     s\n",
    "#                         [ 0.64004815  0.45813003  0.3476034  -0.03451729]     t\n",
    "#                         [-0.22739866 -0.45782727 -0.6643252   0.25129375]]    r\n",
    "\n",
    "#                        [[ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m\n",
    "#                         [-0.7303265  -0.857339    0.58913064 -1.1068314 ]     e\n",
    "#                         [ 0.48159844 -1.4886451   0.92639893  0.76906884]     d\n",
    "#                         [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i\n",
    "#                         [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]     u\n",
    "#                         [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m\n",
    "#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>\n",
    "#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]    <pad>\n",
    "\n",
    "#                        [[ 0.64004815  0.45813003  0.3476034  -0.03451729]     t\n",
    "#                         [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i\n",
    "#                         [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n\n",
    "#                         [-1.284392    0.68294704  1.4064184  -0.42879772]     y\n",
    "#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>\n",
    "#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>\n",
    "#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>\n",
    "#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]]   <pad>\n",
    "# embedded_seq_tensor.shape : (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.4689,  0.1166,  2.3878, -0.1947],\n",
       "         [ 1.6312, -0.8052, -0.9454,  0.9876],\n",
       "         [ 0.7280,  1.6373,  0.8320, -0.1278],\n",
       "         [-2.4629, -1.1885,  0.7875, -1.4379],\n",
       "         [-1.3570,  0.4596, -1.5374, -0.6518],\n",
       "         [-0.3004, -0.5371, -0.4255,  0.3302],\n",
       "         [-1.2131, -0.6293,  1.1262, -0.9834],\n",
       "         [-1.1551,  0.4608, -1.0322,  0.3843]],\n",
       "\n",
       "        [[-0.7203, -2.1817,  0.5739,  0.8252],\n",
       "         [ 0.0346, -0.4814, -0.0607,  1.2942],\n",
       "         [ 0.0693,  1.4610,  0.6951,  1.5176],\n",
       "         [ 0.9889, -0.2478,  1.5609, -1.5656],\n",
       "         [ 0.9400,  1.7982,  0.1903, -0.3458],\n",
       "         [-0.7203, -2.1817,  0.5739,  0.8252],\n",
       "         [-0.4741, -0.9010,  0.2339, -0.3023],\n",
       "         [-0.4741, -0.9010,  0.2339, -0.3023]],\n",
       "\n",
       "        [[-1.2131, -0.6293,  1.1262, -0.9834],\n",
       "         [ 0.9889, -0.2478,  1.5609, -1.5656],\n",
       "         [ 0.7280,  1.6373,  0.8320, -0.1278],\n",
       "         [-1.4483, -1.7291, -1.2276, -0.6637],\n",
       "         [-0.4741, -0.9010,  0.2339, -0.3023],\n",
       "         [-0.4741, -0.9010,  0.2339, -0.3023],\n",
       "         [-0.4741, -0.9010,  0.2339, -0.3023],\n",
       "         [-0.4741, -0.9010,  0.2339, -0.3023]]], grad_fn=<EmbeddingBackward>)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedded_seq_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 7: Call pack_padded_sequence with embedded instances and sequence lengths ##\n",
    "##--------------------------------------------------------------------------------##\n",
    "\n",
    "# The lengths are handed over as a CPU numpy array, in the same (descending) order as the batch rows.\n",
    "sorted_lengths = seq_lengths.cpu().numpy()\n",
    "packed_input = pack_padded_sequence(embedded_seq_tensor, sorted_lengths, batch_first=True)\n",
    "# packed_input is a PackedSequence; the two attributes inspected below are data and batch_sizes.\n",
    "#\n",
    "# packed_input.data =>\n",
    "#                         [[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]     l\n",
    "#                          [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m\n",
    "#                          [ 0.64004815  0.45813003  0.3476034  -0.03451729]     t\n",
    "#                          [-0.23622951  2.0361056   0.15435742 -0.04513785]     o\n",
    "#                          [-0.7303265  -0.857339    0.58913064 -1.1068314 ]     e\n",
    "#                          [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i\n",
    "#                          [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n\n",
    "#                          [ 0.48159844 -1.4886451   0.92639893  0.76906884]     d\n",
    "#                          [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n\n",
    "#                          [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]     g\n",
    "#                          [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i\n",
    "#                          [-1.284392    0.68294704  1.4064184  -0.42879772]     y\n",
    "#                          [-1.6334635  -0.6100042   1.7509955  -1.931793  ]     _\n",
    "#                          [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]     u\n",
    "#                          [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]     s\n",
    "#                          [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m\n",
    "#                          [ 0.64004815  0.45813003  0.3476034  -0.03451729]     t\n",
    "#                          [-0.22739866 -0.45782727 -0.6643252   0.25129375]]    r\n",
    "# packed_input.data.shape : (batch_sum_seq_len X embedding_dim) = (18 X 4)\n",
    "#\n",
    "# packed_input.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1]\n",
    "# visualization :\n",
    "# l  o  n  g  _  s  t  r   #(long_str)\n",
    "# m  e  d  i  u  m         #(medium)\n",
    "# t  i  n  y               #(tiny)\n",
    "# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([18, 4])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "packed_input.data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 8: Forward with LSTM ##\n",
    "##---------------------------##\n",
    "\n",
    "# The LSTM returns the packed per-timestep outputs plus a (hidden state, cell state) pair.\n",
    "packed_output, final_states = lstm(packed_input)\n",
    "ht, ct = final_states\n",
    "# packed_output is a PackedSequence; the two attributes inspected below are data and batch_sizes.\n",
    "#\n",
    "# packed_output.data :\n",
    "#                          [[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l\n",
    "#                           [ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m\n",
    "#                           [ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t\n",
    "#                           [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o\n",
    "#                           [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e\n",
    "#                           [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i\n",
    "#                           [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n\n",
    "#                           [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d\n",
    "#                           [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n\n",
    "#                           [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g\n",
    "#                           [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i\n",
    "#                           [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y\n",
    "#                           [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _\n",
    "#                           [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u\n",
    "#                           [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s\n",
    "#                           [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m\n",
    "#                           [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t\n",
    "#                           [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r\n",
    "# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)\n",
    "\n",
    "# packed_output.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1] (same as packed_input.batch_sizes)\n",
    "# visualization :\n",
    "# l  o  n  g  _  s  t  r   #(long_str)\n",
    "# m  e  d  i  u  m         #(medium)\n",
    "# t  i  n  y               #(tiny)\n",
    "# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([18, 5])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "packed_output.data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.1419,  0.0881, -0.0310, -0.0113,  0.0953],\n",
       "         [ 0.1500, -0.0534, -0.1781, -0.3375,  0.0809],\n",
       "         [ 0.2166,  0.0640, -0.0107, -0.1845,  0.2264]]],\n",
       "       grad_fn=<StackBackward>)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ht"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.3787,  0.1679, -0.0931, -0.0300,  0.1808],\n",
       "         [ 0.2788, -0.1111, -0.5130, -0.6502,  0.1273],\n",
       "         [ 0.3865,  0.1237, -0.0375, -0.4168,  0.3590]]],\n",
       "       grad_fn=<StackBackward>)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 9: Call pad_packed_sequence if required / or just pick last hidden vector ##\n",
    "##--------------------------------------------------------------------------------##\n",
    "\n",
    "# Unpack the LSTM output back into a padded tensor if you need per-timestep outputs.\n",
    "# The rows stay in the length-sorted order from Step 5, and padded positions are filled with zeros.\n",
    "unpacked = pad_packed_sequence(sequence=packed_output, batch_first=True)\n",
    "output, input_sizes = unpacked\n",
    "# output =>\n",
    "#                          [[[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l\n",
    "#                            [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o\n",
    "#                            [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n\n",
    "#                            [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g\n",
    "#                            [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _\n",
    "#                            [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s\n",
    "#                            [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t\n",
    "#                            [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r\n",
    "\n",
    "#                           [[ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m\n",
    "#                            [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e\n",
    "#                            [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d\n",
    "#                            [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i\n",
    "#                            [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u\n",
    "#                            [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m\n",
    "#                            [ 0.          0.          0.          0.          0.        ]   <pad>\n",
    "#                            [ 0.          0.          0.          0.          0.        ]]  <pad>\n",
    "\n",
    "#                           [[ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t\n",
    "#                            [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i\n",
    "#                            [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n\n",
    "#                            [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y\n",
    "#                            [ 0.          0.          0.          0.          0.        ]   <pad>\n",
    "#                            [ 0.          0.          0.          0.          0.        ]   <pad>\n",
    "#                            [ 0.          0.          0.          0.          0.        ]   <pad>\n",
    "#                            [ 0.          0.          0.          0.          0.        ]]] <pad>\n",
    "# output.shape : ( batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.1314,  0.1148, -0.1547, -0.0988,  0.1985],\n",
       "         [ 0.1038, -0.0806, -0.0241, -0.4552, -0.0479],\n",
       "         [ 0.2337,  0.0522, -0.0998, -0.1697,  0.1013],\n",
       "         [ 0.1577,  0.1131, -0.1312, -0.0600,  0.3258],\n",
       "         [ 0.2132,  0.1747, -0.0052, -0.0139,  0.2000],\n",
       "         [ 0.1591,  0.0041, -0.0087, -0.1207,  0.0698],\n",
       "         [ 0.1732,  0.1159, -0.0839, -0.0802,  0.2733],\n",
       "         [ 0.1419,  0.0881, -0.0310, -0.0113,  0.0953]],\n",
       "\n",
       "        [[-0.0932, -0.0520, -0.0876, -0.2888, -0.0819],\n",
       "         [-0.0170, -0.1638, -0.1081, -0.3362, -0.2294],\n",
       "         [ 0.0853, -0.0355, -0.2138, -0.1198, -0.1966],\n",
       "         [ 0.3120,  0.0238, -0.1533, -0.1364,  0.0514],\n",
       "         [ 0.2446, -0.0162, -0.1131, -0.0867,  0.1285],\n",
       "         [ 0.1500, -0.0534, -0.1781, -0.3375,  0.0809],\n",
       "         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],\n",
       "         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],\n",
       "\n",
       "        [[ 0.0523,  0.1086, -0.0756, -0.0563,  0.2486],\n",
       "         [ 0.2768,  0.1651, -0.0704, -0.1376,  0.3154],\n",
       "         [ 0.2434,  0.1325, -0.1023, -0.0901,  0.2692],\n",
       "         [ 0.2166,  0.0640, -0.0107, -0.1845,  0.2264],\n",
       "         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],\n",
       "         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],\n",
       "         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],\n",
       "         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],\n",
       "       grad_fn=<TransposeBackward0>)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.1419,  0.0881, -0.0310, -0.0113,  0.0953],\n",
      "        [ 0.1500, -0.0534, -0.1781, -0.3375,  0.0809],\n",
      "        [ 0.2166,  0.0640, -0.0107, -0.1845,  0.2264]],\n",
      "       grad_fn=<SelectBackward>)\n"
     ]
    }
   ],
   "source": [
    "# Or, if you only need the final hidden state of each sequence:\n",
    "# ht is (1 X batch_size X hidden_dim) here; [-1] selects the last entry -> (batch_size X hidden_dim)\n",
    "final_hidden = ht[-1]\n",
    "print(final_hidden)\n",
    "\n",
    "## Summary of Shape Transformations ##\n",
    "##----------------------------------##\n",
    "\n",
    "# (batch_size X max_seq_len)                 --> Sort by seqlen ---> (batch_size X max_seq_len)\n",
    "# (batch_size X max_seq_len)                 --->     Embed     ---> (batch_size X max_seq_len X embedding_dim)\n",
    "# (batch_size X max_seq_len X embedding_dim) --->      Pack     ---> (batch_sum_seq_len X embedding_dim)\n",
    "# (batch_sum_seq_len X embedding_dim)        --->      LSTM     ---> (batch_sum_seq_len X hidden_dim)\n",
    "# (batch_sum_seq_len X hidden_dim)           --->    UnPack     ---> (batch_size X max_seq_len X hidden_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
